#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

import glob
import json
import logging
import math
import os
import re
import shutil
import tempfile
import unittest
from datetime import datetime
from typing import Any
from typing import List

import fastavro
import hamcrest as hc
import pytest
import pytz
from fastavro import writer
from fastavro.schema import parse_schema

import apache_beam as beam
from apache_beam import Create
from apache_beam.io import avroio
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.avroio import _FastAvroSource  # For testing
from apache_beam.io.avroio import avro_schema_to_beam_schema  # For testing
from apache_beam.io.avroio import beam_schema_to_avro_schema  # For testing
from apache_beam.io.avroio import avro_union_type_to_beam_type  # For testing
from apache_beam.io.avroio import avro_dict_to_beam_row  # For testing
from apache_beam.io.avroio import beam_row_to_avro_dict  # For testing
from apache_beam.io.avroio import _create_avro_sink  # For testing
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability.api import schema_pb2
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.sql import SqlTransform
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.util import LogElements
from apache_beam.typehints import schemas
from apache_beam.utils.timestamp import Timestamp

# Import snappy optionally; some tests will be skipped when import fails.
try:
  import snappy  # pylint: disable=import-error
except ImportError:
  snappy = None  # pylint: disable=invalid-name
  logging.warning(
      'python-snappy is not installed; some tests will be skipped.')

RECORDS = [{
    'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
}, {
    'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
}, {
    'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
}, {
    'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'
}, {
    'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'
}, {
    'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'
}]


class AvroBase(object):
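  """Shared test cases for Avro sources and sinks.

  Subclasses supply SCHEMA and _write_data for a concrete Avro implementation
  (see TestFastAvro below).
  """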

  _temp_files: List[str] = []

  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    self.RECORDS = RECORDS
    self.SCHEMA_STRING = '''
          {"namespace": "example.avro",
           "type": "record",
           "name": "User",
           "fields": [
               {"name": "name", "type": "string"},
               {"name": "favorite_number",  "type": ["int", "null"]},
               {"name": "favorite_color", "type": ["string", "null"]}
           ]
          }
          '''

  def setUp(self):
    # Reduce the size of thread pools. Without this, test execution may fail
    # in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def tearDown(self):
    for path in self._temp_files:
      if os.path.exists(path):
        os.remove(path)
    self._temp_files = []

  def _write_data(
      self,
      directory=None,
      prefix=None,
      codec=None,
      count=None,
      sync_interval=None):
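    """Writes test records to a temporary Avro file and returns its path.

    Implemented by concrete subclasses (see TestFastAvro).
    """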
    raise NotImplementedError

  def _write_pattern(self, num_files, return_filenames=False):
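    """Writes num_files Avro files into a fresh temp directory.

    Returns a glob pattern matching them (and also the list of file names when
    return_filenames is True).
    """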
    assert num_files > 0
    temp_dir = tempfile.mkdtemp()

    file_name = None
    file_list = []
    for _ in range(num_files):
      file_name = self._write_data(directory=temp_dir, prefix='mytemp')
      file_list.append(file_name)

    assert file_name
    file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
    if return_filenames:
      return (file_name_prefix + os.path.sep + 'mytemp*', file_list)
    return file_name_prefix + os.path.sep + 'mytemp*'

  def _run_avro_test(
      self, pattern, desired_bundle_size, perform_splitting, expected_result):
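    """Reads pattern with _FastAvroSource and verifies the result.

    With perform_splitting, splits the source and checks that the splits read
    the same records as the unsplit reference source; otherwise compares the
    records read directly against expected_result.
    """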
    source = _FastAvroSource(pattern)

    if perform_splitting:
      assert desired_bundle_size
      splits = list(source.split(desired_bundle_size=desired_bundle_size))
      if len(splits) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')

      sources_info = [(split.source, split.start_position, split.stop_position)
                      for split in splits]
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)

  def test_schema_read_write(self):
    with tempfile.TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      rows = [beam.Row(a=1, b=['x', 'y']), beam.Row(a=2, b=['t', 'u'])]
      stable_repr = lambda row: json.dumps(row._asdict())
      with TestPipeline() as p:
        _ = p | Create(rows) | avroio.WriteToAvro(path) | beam.Map(print)
      with TestPipeline() as p:
        readback = (
            p
            | avroio.ReadFromAvro(path + '*', as_rows=True)
            | beam.Map(stable_repr))
        assert_that(readback, equal_to([stable_repr(r) for r in rows]))

  @pytest.mark.xlang_sql_expansion_service
  @unittest.skipIf(
      TestPipeline().get_pipeline_options().view_as(StandardOptions).runner
      is None,
      "Must be run with a runner that supports staging java artifacts.")
  def test_avro_schema_to_beam_schema_with_nullable_atomic_fields(self):
    records = []
    records.extend(self.RECORDS)
    records.append({
        'name': 'Bruce', 'favorite_number': None, 'favorite_color': None
    })
    avro_schema = fastavro.parse_schema(json.loads(self.SCHEMA_STRING))
    beam_schema = avro_schema_to_beam_schema(avro_schema)

    with TestPipeline() as p:
      readback = (
          p
          | Create(records)
          | beam.Map(avro_dict_to_beam_row(avro_schema, beam_schema))
          | SqlTransform("SELECT * FROM PCOLLECTION")
          | beam.Map(beam_row_to_avro_dict(avro_schema, beam_schema)))
      assert_that(readback, equal_to(records))

  def test_avro_union_type_to_beam_type_with_nullable_long(self):
    union_type = ['null', 'long']
    beam_type = avro_union_type_to_beam_type(union_type)
    expected_beam_type = schema_pb2.FieldType(
        atomic_type=schema_pb2.INT64, nullable=True)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_union_type_to_beam_type_with_string_long(self):
    union_type = ['string', 'long']
    beam_type = avro_union_type_to_beam_type(union_type)
    expected_beam_type = schemas.typing_to_runner_api(Any)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_union_type_to_beam_type_with_record_and_null(self):
    record_type = {
        'type': 'record',
        'name': 'TestRecord',
        'fields': [{
            'name': 'field1', 'type': 'string'
        }, {
            'name': 'field2', 'type': 'int'
        }]
    }
    union_type = [record_type, 'null']
    beam_type = avro_union_type_to_beam_type(union_type)
    expected_beam_type = schema_pb2.FieldType(
        row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                fields=[
                    schemas.schema_field(
                        'field1',
                        schema_pb2.FieldType(atomic_type=schema_pb2.STRING)),
                    schemas.schema_field(
                        'field2',
                        schema_pb2.FieldType(atomic_type=schema_pb2.INT32))
                ])),
        nullable=True)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_union_type_to_beam_type_with_nullable_annotated_string(self):
    annotated_string_type = {"avro.java.string": "String", "type": "string"}
    union_type = ['null', annotated_string_type]

    beam_type = avro_union_type_to_beam_type(union_type)

    expected_beam_type = schema_pb2.FieldType(
        atomic_type=schema_pb2.STRING, nullable=True)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_union_type_to_beam_type_with_only_null(self):
    union_type = ['null']
    beam_type = avro_union_type_to_beam_type(union_type)
    expected_beam_type = schemas.typing_to_runner_api(Any)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_union_type_to_beam_type_with_multiple_types(self):
    union_type = ['null', 'string', 'int']
    beam_type = avro_union_type_to_beam_type(union_type)
    expected_beam_type = schemas.typing_to_runner_api(Any)
    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))

  def test_avro_schema_to_beam_and_back(self):
    avro_schema = fastavro.parse_schema(json.loads(self.SCHEMA_STRING))
    beam_schema = avro_schema_to_beam_schema(avro_schema)
    converted_avro_schema = beam_schema_to_avro_schema(beam_schema)
    expected_fields = json.loads(self.SCHEMA_STRING)["fields"]
    hc.assert_that(
        converted_avro_schema["fields"], hc.equal_to(expected_fields))

  def test_read_without_splitting(self):
    file_name = self._write_data()
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting(self):
    file_name = self._write_data()
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  def test_source_display_data(self):
    file_name = 'some_avro_source'
    source = _FastAvroSource(file_name, validate=False)
    dd = DisplayData.create_from(source)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_display_data(self):
    file_name = 'some_avro_source'
    read = avroio.ReadFromAvro(file_name, validate=False)
    dd = DisplayData.create_from(read)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_sink_display_data(self):
    file_name = 'some_avro_sink'
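    # Positional args below (matching the _create_avro_sink call signature):
    # file_path_prefix, schema, codec, file_name_suffix, num_shards,
    # shard_name_template, mime_type.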
    sink = _create_avro_sink(
        file_name, self.SCHEMA, 'null', '.end', 0, None, 'application/x-avro')
    dd = DisplayData.create_from(sink)

    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
        DisplayDataItemMatcher('codec', 'null'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]

    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_display_data(self):
    file_name = 'some_avro_sink'
    write = avroio.WriteToAvro(file_name, self.SCHEMA)
    write.expand(beam.PCollection(beam.Pipeline()))
    dd = DisplayData.create_from(write)
    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('codec', 'deflate'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_reentrant_without_splitting(self):
    file_name = self._write_data()
    source = _FastAvroSource(file_name)
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_read_reentrant_with_splitting(self):
    file_name = self._write_data()
    source = _FastAvroSource(file_name)
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    source_test_utils.assert_reentrant_reads_succeed(
        (splits[0].source, splits[0].start_position, splits[0].stop_position))

  def test_read_without_splitting_multiple_blocks(self):
    file_name = self._write_data(count=12000)
    expected_result = self.RECORDS * 2000
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting_multiple_blocks(self):
    file_name = self._write_data(count=12000)
    expected_result = self.RECORDS * 2000
    self._run_avro_test(file_name, 10000, True, expected_result)

  def test_split_points(self):
    num_records = 12000
    sync_interval = 16000
    file_name = self._write_data(
        count=num_records, sync_interval=sync_interval)

    source = _FastAvroSource(file_name)

    splits = list(source.split(desired_bundle_size=float('inf')))
    assert len(splits) == 1
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())
    # The generated test file contains num_blocks blocks: roughly the total
    # serialized size of the records (about 14.5 bytes per record here)
    # divided by the synchronization interval used by Avro during the write.
    # Each block has more than 10 records.
    num_blocks = int(math.ceil(14.5 * num_records / sync_interval))
    assert num_blocks > 1
    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (num_blocks - 1, 1)
    self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)

  def test_read_without_splitting_compressed_deflate(self):
    file_name = self._write_data(codec='deflate')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting_compressed_deflate(self):
    file_name = self._write_data(codec='deflate')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_read_without_splitting_compressed_snappy(self):
    file_name = self._write_data(codec='snappy')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_read_with_splitting_compressed_snappy(self):
    file_name = self._write_data(codec='snappy')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  def test_read_without_splitting_pattern(self):
    pattern = self._write_pattern(3)
    expected_result = self.RECORDS * 3
    self._run_avro_test(pattern, None, False, expected_result)

  def test_read_with_splitting_pattern(self):
    pattern = self._write_pattern(3)
    expected_result = self.RECORDS * 3
    self._run_avro_test(pattern, 100, True, expected_result)

  def test_dynamic_work_rebalancing_exhaustive(self):
    def compare_split_points(file_name):
      source = _FastAvroSource(file_name)
      splits = list(source.split(desired_bundle_size=float('inf')))
      assert len(splits) == 1
      source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)

    # Adjust the block size so that we can perform an exhaustive dynamic
    # work rebalancing test that completes within an acceptable amount of time.
    file_name = self._write_data(count=5, sync_interval=2)

    compare_split_points(file_name)

  def test_corrupted_file(self):
    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file, which is also the last character
    # of the last sync marker.
    # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (corrupted_data[-1] + 1) % 256
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name

    source = _FastAvroSource(corrupted_file_name)
    with self.assertRaisesRegex(ValueError, r'expected sync marker'):
      source_test_utils.read_from_source(source, None, None)

  def test_read_from_avro(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS))

  def test_read_all_from_avro_single_file(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS))

  def test_read_all_from_avro_many_single_files(self):
    path1 = self._write_data()
    path2 = self._write_data()
    path3 = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 3))

  def test_read_all_from_avro_file_pattern(self):
    file_pattern = self._write_pattern(5)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 5))

  def test_read_all_from_avro_many_file_patterns(self):
    file_pattern1 = self._write_pattern(5)
    file_pattern2 = self._write_pattern(2)
    file_pattern3 = self._write_pattern(3)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 10))

  def test_read_all_from_avro_with_filename(self):
    file_pattern, file_paths = self._write_pattern(3, return_filenames=True)
    result = [(path, record) for path in file_paths for record in self.RECORDS]
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | avroio.ReadAllFromAvro(with_filename=True),
          equal_to(result))

  class _WriteFilesFn(beam.DoFn):
    """Writes extra files from within a running pipeline (with deferral).

    Lets the ReadAllFromAvroContinuously tests observe files that appear or
    are updated while the pipeline runs.
    """

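    # Per-key counter so the extra files are written exactly once, by the
    # first element processed.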
    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def __init__(self, SCHEMA, RECORDS, tempdir):
      self._thread = None
      self.SCHEMA = SCHEMA
      self.RECORDS = RECORDS
      self.tempdir = tempdir

    def get_expect(self, match_updated_files):
      results_file1 = [('file1', x) for x in self.gen_records(1)]
      results_file2 = [('file2', x) for x in self.gen_records(3)]
      if match_updated_files:
        results_file1 += [('file1', x) for x in self.gen_records(2)]
      return results_file1 + results_file2

    def gen_records(self, count):
      return (
          self.RECORDS * (count // len(self.RECORDS)) +
          self.RECORDS[:count % len(self.RECORDS)])

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      counter = count_state.read()
      if counter == 0:
        count_state.add(1)
        with open(FileSystems.join(self.tempdir, 'file1'), 'wb') as f:
          writer(f, self.SCHEMA, self.gen_records(2))
        with open(FileSystems.join(self.tempdir, 'file2'), 'wb') as f:
          writer(f, self.SCHEMA, self.gen_records(3))
      # Re-key the output from the dummy key to the file basename.
      basename = FileSystems.split(element[1][0])[1]
      content = element[1][1]
      yield basename, content

  def test_read_all_continuously_new(self):
    with TestPipeline() as pipeline:
      tempdir = tempfile.mkdtemp()
      writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
      with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
        writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
      match_pattern = FileSystems.join(tempdir, '*')
      interval = 0.5
      last = 2

      p_read_once = (
          pipeline
          | 'Continuously read new files' >> avroio.ReadAllFromAvroContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=False)
          | 'add dummy key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
      assert_that(
          p_read_once,
          equal_to(writer_fn.get_expect(match_updated_files=False)),
          label='assert read new files results')

  def test_read_all_continuously_update(self):
    with TestPipeline() as pipeline:
      tempdir = tempfile.mkdtemp()
      writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
      with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
        writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
      match_pattern = FileSystems.join(tempdir, '*')
      interval = 0.5
      last = 2

      p_read_upd = (
          pipeline
          | 'Continuously read updated files' >>
          avroio.ReadAllFromAvroContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=True)
          | 'add dummy key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
      assert_that(
          p_read_upd,
          equal_to(writer_fn.get_expect(match_updated_files=True)),
          label='assert read updated files results')

  def test_sink_transform(self):
    with tempfile.NamedTemporaryFile() as dst:
      path = dst.name
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p \
        | beam.Create(self.RECORDS) \
        | avroio.WriteToAvro(path, self.SCHEMA)
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | avroio.ReadFromAvro(path + '*') \
            | beam.Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_sink_transform_snappy(self):
    with tempfile.NamedTemporaryFile() as dst:
      path = dst.name
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p \
        | beam.Create(self.RECORDS) \
        | avroio.WriteToAvro(
            path,
            self.SCHEMA,
            codec='snappy')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | avroio.ReadFromAvro(path + '*') \
            | beam.Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_writer_open_and_close(self):
    # Create and then close a temp file so we can manually open it later
    dst = tempfile.NamedTemporaryFile(delete=False)
    dst.close()

    schema = parse_schema(json.loads(self.SCHEMA_STRING))
    sink = _create_avro_sink(
        'some_avro_sink', schema, 'null', '.end', 0, None, 'application/x-avro')

    w = sink.open(dst.name)

    sink.close(w)

    os.unlink(dst.name)


class TestFastAvro(AvroBase, unittest.TestCase):
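  """Runs the shared AvroBase test cases using fastavro for read and write."""
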
  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING))

  def _write_data(
      self,
      directory=None,
      prefix=tempfile.template,
      codec='null',
      count=len(RECORDS),
      **kwargs):
    all_records = (
        self.RECORDS * (count // len(self.RECORDS)) +
        self.RECORDS[:count % len(self.RECORDS)])
    with tempfile.NamedTemporaryFile(delete=False,
                                     dir=directory,
                                     prefix=prefix,
                                     mode='w+b') as f:
      writer(f, self.SCHEMA, all_records, codec=codec, **kwargs)
      self._temp_files.append(f.name)
    return f.name


class GenerateEvent(beam.PTransform):
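  """Produces a TestStream of timestamped sample events for streaming tests."""
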
  @staticmethod
  def sample_data():
    return GenerateEvent()

  def expand(self, pcoll):
    elem = [{'age': 10}, {'age': 20}, {'age': 30}]

    def event_time(second):
      return datetime(
          2021, 3, 1, 0, 0, second, 0, tzinfo=pytz.UTC).timestamp()

    # Emit the three sample elements once per second for seconds 1..20,
    # advancing the watermark every five seconds so that downstream windows
    # can fire deterministically.
    stream = TestStream()
    for second in range(1, 21):
      if second % 5 == 0:
        stream = stream.advance_watermark_to(event_time(second))
      stream = stream.add_elements(
          elements=elem, event_timestamp=event_time(second))
    stream = stream.advance_watermark_to(
        event_time(25)).advance_watermark_to_infinity()
    return pcoll | stream


class WriteStreamingTest(unittest.TestCase):
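  """Tests windowed Avro file naming when writing in streaming mode."""
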
  def setUp(self):
    super().setUp()
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | 'User windowing' >> beam.transforms.core.WindowInto(
              beam.transforms.window.FixedWindows(60),
              trigger=beam.transforms.trigger.AfterWatermark(),
              accumulation_mode=beam.transforms.trigger.AccumulationMode.
              DISCARDING,
              allowed_lateness=beam.utils.timestamp.Duration(seconds=0)))
      # AvroIO
      avroschema = {
          'name': 'dummy',  # the Avro record name (not an output file name)
          'type': 'record',  # the Avro complex type; see the Avro spec for
                             # the other types
          'fields': [  # the record's field names and their types
              {'name': 'age', 'type': 'int'},
          ],
      }
      output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro(
          file_path_prefix=self.tempdir + "/output_WriteToAvro",
          file_name_suffix=".avro",
          num_shards=num_shards,
          schema=avroschema)
      _ = output2 | 'LogElements after WriteToAvro' >> LogElements(
          prefix='after WriteToAvro ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    #   output_WriteToAvro-[1614556800.0, 1614556860.0)-00000-of-00002.avro
    # It captures: window_start, window_end, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>[\d\.]+), '
        r'(?P<window_end>[\d\.]+|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.avro$')
    pattern = re.compile(pattern_string)
    file_names = []
    for file_name in glob.glob(self.tempdir + '/output_WriteToAvro*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        num_shards,
        "expected %d files, but got: %d" % (num_shards, len(file_names)))

  def test_write_streaming_2_shards_custom_shard_name_template(
      self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
    with TestPipeline() as p:
      output = (p | GenerateEvent.sample_data())
      # AvroIO
      avroschema = {
          'name': 'dummy',  # the Avro record name (not an output file name)
          'type': 'record',  # the Avro complex type
          'fields': [  # the record's field names and their types
              {'name': 'age', 'type': 'int'},
          ],
      }
      output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro(
          file_path_prefix=self.tempdir + "/output_WriteToAvro",
          file_name_suffix=".avro",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=60,
          schema=avroschema)
      _ = output2 | 'LogElements after WriteToAvro' >> LogElements(
          prefix='after WriteToAvro ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    #   output_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
    #     00000-of-00002.avro
    # It captures: window_start, window_end, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
        r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.avro$')
    pattern = re.compile(pattern_string)
    file_names = []
    for file_name in glob.glob(self.tempdir + '/output_WriteToAvro*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        num_shards,
        "expected %d files, but got: %d" % (num_shards, len(file_names)))

  def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
      self,
      num_shards=2,
      shard_name_template='-V-SSSSS-of-NNNNN',
      triggering_frequency=5):
    with TestPipeline() as p:
      output = (p | GenerateEvent.sample_data())
      # AvroIO
      avroschema = {
          'name': 'dummy',  # the Avro record name (not an output file name)
          'type': 'record',  # the Avro complex type
          'fields': [  # the record's field names and their types
              {'name': 'age', 'type': 'int'},
          ],
      }
      output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro(
          file_path_prefix=self.tempdir + "/output_WriteToAvro",
          file_name_suffix=".txt",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=triggering_frequency,
          schema=avroschema)
      _ = output2 | 'LogElements after WriteToAvro' >> LogElements(
          prefix='after WriteToAvro ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    #   output_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-00-05)-
    #     00000-of-00002.txt
    # It captures: window_start, window_end, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
        r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.txt$')
    pattern = re.compile(pattern_string)
    file_names = []
    for file_name in glob.glob(self.tempdir + '/output_WriteToAvro*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    # With a 5s triggering frequency, the input spans 5 windows, giving
    # num_shards files per window.
    expected_num_files = num_shards * 5
    self.assertEqual(
        len(file_names),
        expected_num_files,
        "expected %d files, but got: %d" %
        (expected_num_files, len(file_names)))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
